package org.zjvis.datascience.common.etl;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.ETLEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;
import org.zjvis.datascience.common.vo.TaskVO;

/**
 * @description ETL-JOIN 多表联结类
 * @date 2021-12-27
 */
public class Join extends BaseETL {

    // private static String JOIN_SQL = "select %s from %s a %s %s b on trim(cast(a.\"%s\" as text)) %s trim(cast(b.\"%s\" as text))";
    private static String JOIN_SQL = "select row_number() over() as _record_id_, %s from %s a %s %s b on %s";

    // private static String JOIN_SQL_SAMPLE = "select %s from %s a %s %s b on trim(cast(a.\"%s\" as text)) %s trim(cast(b.\"%s\" as text)) limit %s";
    private static String JOIN_SQL_SAMPLE = "select row_number() over() as _record_id_, %s from %s a %s %s b on %s limit %s";

    private final static Logger logger = LoggerFactory.getLogger("join");

    private static final double MIN_REC_SCORE = 0.8;

    public static double getMinRecScore(){
        return MIN_REC_SCORE;
    }

    public enum JoinType {
        LEFT_JOIN("LEFT JOIN", 1),
        RIGHT_JOIN("RIGHT JOIN", 2),
        JOIN("JOIN", 3),
        OUTER_JOIN("FULL OUTER JOIN", 4),
        LEFT_NOT_JOIN("LEFT JOIN", 5),
        RIGHT_NOT_JOIN("RIGHT JOIN", 6),
        OUTER_NOT_JOIN("FULL OUTER JOIN", 7);

        private String desc;

        private int val;

        JoinType(String desc, int val) {
            this.desc = desc;
            this.val = val;
        }

        public String getDesc() {
            return desc;
        }

        int getVal() {
            return val;
        }
    }

    // join条件
    private JSONArray rules = new JSONArray();

    private String left;

    private String right;

    // 1 left join;
    private int mod;

    private String operator;

    public Join() {
        super(ETLEnum.JOIN.name(), SubTypeEnum.ETL_OPERATE.getVal(),
            SubTypeEnum.ETL_OPERATE.getDesc());
        this.maxParentNumber = 2;
    }

    public void parserConf(JSONObject conf) {
        // 处理autojoin autoJoinRecommendation
        // 初次用autoJoin推荐出来的第一个匹配条件进行join查询
        if (conf.containsKey("autoJoinRecommendation") && (!conf.containsKey("conditions")
            || conf.getJSONArray("conditions").size() == 0)) {
            // 已经触发了推荐并且conditions为空或者不存在, 根据推荐结果进行初始化
            JSONArray recommendRules = conf.getJSONArray("autoJoinRecommendation");
            int index = 0;
            if (recommendRules.size() > 0) {
                do {
                    JSONObject rule = recommendRules.getJSONObject(index);
                    double score = rule.getDouble("score");
                    if (score >= MIN_REC_SCORE) {
                        rule.put("operator", "=");
                        rules.add(rule);
                    }
                } while (false);
            }
            conf.put("conditions", SerializationUtils.clone(rules));
        } else if (conf.containsKey("conditions") && conf.getJSONArray("conditions").size() > 0) {
            // 用户配置join规则情况
            rules = conf.getJSONArray("conditions");
        }
        this.mod = conf.getInteger("mod");
    }

    private String buildConditions() {
        ArrayList<String> result = new ArrayList<>();
        int index = 0;
        for (; index < rules.size(); ++index) {
            JSONObject item = rules.getJSONObject(index);
            String leftHeaderName = item.getString("leftHeaderName");
            String rightHeaderName = item.getString("rightHeaderName");
            String operator = item.getString("operator");
            result.add(
                String.format("a.\"%s\" %s b.\"%s\"", leftHeaderName, operator, rightHeaderName));
        }
        return Joiner.on(" and ").join(result);
    }

    public String initSql(JSONObject conf, List<SqlHelper> sqlHelpers, long timeStamp,
        String engineName) {
        this.engineName = engineName;
        if (sqlHelpers.size() != 2) {
            logger.error("sqlHelper number is not equal with 2");
            return StringUtils.EMPTY;
        }
//    if (!validate(sqlHelpers.get(0), sqlHelpers.get(1))) {
//      logger.error("validate fail, params type not equal with each other");
//      return null;
//    }
        JSONArray input = conf.getJSONArray("input");
        if (input.size() < 2) {
            logger.error("input table is less  than 2");
            return null;
        }
        if (rules.size() == 0) {
            return StringUtils.EMPTY;
        }
        String tableLeft = input.getJSONObject(0).getString("tableName");
        tableLeft = ToolUtil.alignTableName(tableLeft, timeStamp);
        String tableRight = input.getJSONObject(1).getString("tableName");
        tableRight = ToolUtil.alignTableName(tableRight, timeStamp);
        List<String> leftFields = input.getJSONObject(0).getJSONArray("tableCols")
            .toJavaList(String.class);
        if (leftFields.contains("_record_id_")){
            leftFields.remove("_record_id_");
        }
        List<String> rightFields = input.getJSONObject(1).getJSONArray("tableCols")
            .toJavaList(String.class);
        if (rightFields.contains("_record_id_")){
            rightFields.remove("_record_id_");
        }
        List<String> leftFieldKeys = SqlTemplate.getColumns(leftFields);
        List<String> rightFieldKeys = SqlTemplate.getColumns(rightFields);
        leftFields = SqlTemplate.columnsRenameSql(leftFields, "a", null, rightFieldKeys);
        rightFields = SqlTemplate.columnsRenameSql(rightFields, "b", null, leftFieldKeys);
        ToolUtil.startsWithFilter(leftFields, "_record_id__");
        ToolUtil.startsWithFilter(rightFields, "_record_id__");
        //_record_id_放在首位，gp默认用第一个字段作为分布键
        String selectColumns = String
            .format("%s,%s ", Joiner.on(", ").join(leftFields), Joiner.on(", ").join(rightFields));
        String selectLeftColumns = Joiner.on(", ").join(leftFields);
        String selectRightColumns = Joiner.on(", ").join(rightFields);
        String selectSql = "";

        String mode = "";
        String whereStr = "";
        switch (mod) {
            case 1: {
                mode = JoinType.LEFT_JOIN.getDesc();
                break;
            }
            case 2: {
                mode = JoinType.RIGHT_JOIN.getDesc();
                break;
            }
            case 3: {
                mode = JoinType.JOIN.getDesc();
                break;
            }
            case 4: {
                mode = JoinType.OUTER_JOIN.getDesc();
                break;
            }
            case 5: {
                mode = JoinType.LEFT_NOT_JOIN.getDesc();
                selectColumns = selectLeftColumns;
                whereStr = " where b._record_id_ is null";
                break;
            }
            case 6: {
                mode = JoinType.RIGHT_NOT_JOIN.getDesc();
                selectColumns = selectRightColumns;
                whereStr = " where a._record_id_ is null";
                break;
            }
            case 7: {
                mode = JoinType.OUTER_NOT_JOIN.getDesc();
                whereStr = " where a._record_id_ is null or b._record_id_ is null";
                break;
            }
            default:
                break;
        }

        String conditionStr = buildConditions();
        conditionStr += whereStr;

        if (!conf.containsKey("isSample") || conf.getString("isSample").equals("SUCCESS") || conf
            .getString("isSample").equals("FAIL")) {
            selectSql = String
                .format(JOIN_SQL_SAMPLE, selectColumns, tableLeft, mode, tableRight, conditionStr,
                    SAMPLE_NUMBER);
            conf.put("isSample", "CREATE");
        } else {
            selectSql = String
                .format(JOIN_SQL, selectColumns, tableLeft, mode, tableRight, conditionStr);
        }

        JSONArray output = conf.getJSONArray("output");
        String outTable = output.getJSONObject(0).getString("tableName");
        outTable = ToolUtil.alignTableName(outTable, timeStamp);
        return String.format(SqlTemplate.CREATE_TABLE_SQL, outTable, selectSql);
    }


    private void sortInput(JSONObject jsonObject, JSONArray input) {
        try {
            for (int i = 0; i < 2; ++i) {
                List<String> cols = input.getJSONObject(i).getJSONArray("tableCols")
                        .toJavaList(String.class);
                List<String> types = input.getJSONObject(i).getJSONArray("columnTypes")
                        .toJavaList(String.class);
                int index = cols.indexOf("_record_id_");
                cols.remove(index);
                types.remove(index);
                ToolUtil.startsWithFilter(cols, "_record_id_");
                ToolUtil.listWithIndex(cols);
                ToolUtil.sortList(cols);
                List<String> sortCols = new ArrayList<>();
                List<String> sortTypes = new ArrayList<>();
                ToolUtil.getColAndType(cols, sortCols, types, sortTypes);
                JSONObject item = input.getJSONObject(i);
                item.put("tableCols", sortCols);
                item.put("columnTypes", sortTypes);
                input.set(i, item);
            }
            jsonObject.put("input", input);
        }catch (Exception e){
            throw DataScienceException.of(BaseErrorCode.TASK_NOT_PROPER_INIT, "字段获取失败");
        }
    }

    public void defineOutput(TaskVO vo) {
        try {
            JSONObject jsonObject = vo.getData();
            JSONArray output = new JSONArray();
            JSONObject item = new JSONObject();
            String tableName = String
                    .format(SqlTemplate.OUT_TABLE_NAME, "join", vo.getPipelineId(), vo.getId());
            JSONArray outputColumnTypes = new JSONArray();
            JSONArray input = jsonObject.getJSONArray("input");
            if (input.size() < 2) {
                logger.debug("input size is less 2");
                return;
            }
            sortInput(jsonObject, input);
            JSONObject input1 = input.getJSONObject(0);
            JSONObject input2 = input.getJSONObject(1);
            List<String> leftFields = input1.getJSONArray("tableCols").toJavaList(String.class);
            if (leftFields.contains("_record_id_")){
                leftFields.remove("_record_id_");
            }
            List<String> leftFieldKeys = SqlTemplate.getColumns(leftFields);
            JSONObject leftSemantic = input1.getJSONObject("semantic");
            JSONObject leftCategoryOrder = input1.getJSONObject("categoryOrder");
            // SqlTemplate.startsWithFilter(leftFields, "_record_id__");
            List<String> rightFields = input2.getJSONArray("tableCols")
                    .toJavaList(String.class);
            if (rightFields.contains("_record_id_")){
                rightFields.remove("_record_id_");
            }
            List<String> rightFieldKeys = SqlTemplate.getColumns(rightFields);
            JSONObject rightSemantic = input2.getJSONObject("semantic");
            JSONObject rightCategoryOrder = input2.getJSONObject("categoryOrder");

            JSONObject leftNumberFormat = input1.getJSONObject("numberFormat");
            if (leftNumberFormat == null) {
                leftNumberFormat = new JSONObject();
            }
            if (leftNumberFormat != null && leftFields != null && !leftFields.isEmpty()) {
                Set<String> cols = leftNumberFormat.keySet();
                for (String col : cols) {
                    if (!leftFields.contains(col)) {
                        leftNumberFormat.remove(col);
                    }
                }
            }
            JSONObject rightNumberFormat = input2.getJSONObject("numberFormat");
            if (rightNumberFormat == null) {
                rightNumberFormat = new JSONObject();
            }
            if (rightNumberFormat != null && rightFields != null && !rightFields.isEmpty()) {
                Set<String> cols = rightNumberFormat.keySet();
                for (String col : cols) {
                    if (!rightFields.contains(col)) {
                        rightNumberFormat.remove(col);
                    }
                }
            }
            leftFields = SqlTemplate
                    .columnInfoRename(leftFields, "a", leftSemantic, leftNumberFormat, leftCategoryOrder, rightFieldKeys);
            rightFields = SqlTemplate
                    .columnInfoRename(rightFields, "b", rightSemantic, rightNumberFormat, rightCategoryOrder, leftFieldKeys);
            // SqlTemplate.startsWithFilter(rightFields, "_record_id__");
            List<String> leftFieldTypes = input1.getJSONArray("columnTypes")
                    .toJavaList(String.class);
            List<String> rightFieldTypes = input2.getJSONArray("columnTypes")
                    .toJavaList(String.class);
            outputColumnTypes.addAll(leftFieldTypes);
            outputColumnTypes.addAll(rightFieldTypes);
            outputColumnTypes.add("BIGINT");
            JSONArray cols = new JSONArray();
            cols.addAll(leftFields);
            cols.addAll(rightFields);
            JSONObject newSemantic = new JSONObject();
            newSemantic.putAll(leftSemantic);
            newSemantic.putAll(rightSemantic);
            JSONObject newCategoryOrder = new JSONObject();
            if (leftCategoryOrder != null) {
                newCategoryOrder.putAll(leftCategoryOrder);
            }
            if (rightCategoryOrder != null) {
                newCategoryOrder.putAll(rightCategoryOrder);
            }
            JSONObject newNumberFormat = new JSONObject();
            newNumberFormat.putAll(leftNumberFormat);
            newNumberFormat.putAll(rightNumberFormat);
            cols.add("_record_id_");
            item.put("numberFormat", newNumberFormat);
            item.put("tableName", tableName);
            item.put("tableCols", cols);
            item.put("nodeName", vo.getName() == null ? ETLEnum.JOIN.toString() : vo.getName());
            item.put("columnTypes", outputColumnTypes);
            item.put("semantic", newSemantic);
            if (newCategoryOrder.size() > 0) {
                item.put("categoryOrder", newCategoryOrder);
            }
            this.setSubTypeForOutput(item);
            output.add(item);
            jsonObject.put("output", output);
            vo.setData(jsonObject);
        }catch (DataScienceException e){
            logger.error("error happened when defining task output, since {}", e.getMessage());
            vo.setException(e);
        }
    }

    public void initTemplate(JSONObject data) {
        data.put("mod", 3);
        data.put("conditions", rules);
        baseInitTemplate(data);
    }

    private boolean validate(SqlHelper sqlHelperLeft, SqlHelper sqlHelperRight) {
        boolean isValid = true;
        if (sqlHelperLeft.getColumnType(left) != sqlHelperRight.getColumnType(right)) {
            isValid = false;
        }
        return isValid;
    }
}
